﻿#include "ContactHandling.cginc"
#include "ColliderDefinitions.cginc"
#include "Rigidbody.cginc"
#include "Simplex.cginc"
#include "CollisionMaterial.cginc"
#include "AtomicDeltas.cginc"

#pragma kernel Clear
#pragma kernel Initialize
#pragma kernel Project
#pragma kernel Apply

// Maps a dispatch thread index to a particle index (read by the Apply kernel).
StructuredBuffer<int> particleIndices;

// Particle indices of all simplices; addressed as simplices[simplexStart + j].
StructuredBuffer<int> simplices;
// Per-particle read-only state, all indexed by particle index.
StructuredBuffer<float> invMasses;
StructuredBuffer<float> invRotationalMasses;
StructuredBuffer<float4> prevPositions;
StructuredBuffer<float4> prevOrientations;
StructuredBuffer<float4> principalRadii;
StructuredBuffer<float4> velocities;

// Collider data. 'shapes' is indexed by a contact's bodyB and supplies rigidbodyIndex,
// materialIndex and isTrigger(). 'transforms' is not referenced directly by the kernels
// in this file (presumably used by included helpers -- TODO confirm).
StructuredBuffer<transform> transforms;
StructuredBuffer<shape> shapes;
// Writable rigidbody data: constraintCount is reset in Clear and incremented in Initialize.
RWStructuredBuffer<rigidbody> RW_rigidbodies;

// Current particle state. 'positions' and 'orientations' are read in Project, and
// 'positions' is handed to ApplyPositionDelta in Apply. 'deltas' is not referenced
// directly by the kernels in this file.
RWStructuredBuffer<float4> positions;
RWStructuredBuffer<float4> orientations;
RWStructuredBuffer<float4> deltas;

// One contact per thread in Clear/Initialize/Project. Initialize updates each contact's
// dist and basis and fills the matching effectiveMasses entry, which Project then reads.
RWStructuredBuffer<contact> contacts;
RWStructuredBuffer<contactMasses> effectiveMasses;
// Element [3] is used as the exclusive upper bound for the contact index in every contact kernel.
StructuredBuffer<uint> dispatchBuffer;

// Only element [0] is used; it is passed to the rigidbody helpers.
StructuredBuffer<inertialFrame> inertialSolverFrame;

// Variables set from the CPU
uint particleCount;     // exclusive upper bound for the thread index in Apply.
float maxDepenetration; // multiplied by stepTime and passed to SolvePenetration.
float substepTime;      // divides timeLeft to obtain the number of substeps left.
float stepTime;         // stepTime * steps is the time horizon used for rigidbody motion in Project.
int steps;
float timeLeft;
float sorFactor;        // forwarded to ApplyPositionDelta in Apply.

// Clear kernel: one thread per contact.
// Zeroes the constraint counter of the rigidbody (if any) referenced by the contact's
// collider. The Initialize kernel increments this same counter once per contact.
[numthreads(128, 1, 1)]
void Clear (uint3 id : SV_DispatchThreadID)
{
    uint contactIndex = id.x;

    // dispatchBuffer[3] is the exclusive upper bound for valid contact indices.
    if (contactIndex >= dispatchBuffer[3])
        return;

    int bodyIndex = shapes[contacts[contactIndex].bodyB].rigidbodyIndex;

    // A negative index means the collider has no rigidbody: nothing to reset.
    if (bodyIndex < 0)
        return;

    // Several contacts may reference the same rigidbody, so the write is done atomically.
    // The previous value is required by the intrinsic's signature but is not used.
    int previousCount;
    InterlockedExchange(RW_rigidbodies[bodyIndex].constraintCount, 0, previousCount);
}

// Initialize kernel: one thread per contact.
// Interpolates the state of simplex A at the contact's barycentric coordinates, refreshes
// the contact distance and tangent basis, and stores the contact (inverse) masses of both
// bodies in effectiveMasses[i]. Also counts this contact on the collider's rigidbody, if any.
[numthreads(128, 1, 1)]
void Initialize (uint3 id : SV_DispatchThreadID)
{
    uint i = id.x;

    // dispatchBuffer[3] is the exclusive upper bound for valid contact indices.
    if (i >= dispatchBuffer[3])
        return;

    int sizeA;
    int startA = GetSimplexStartAndSize(contacts[i].bodyA, sizeA);
    int colliderIndex = contacts[i].bodyB;

    // The simplex material is taken from its first particle. A negative index means "no material".
    int materialA = collisionMaterialIndices[simplices[startA]];
    bool useRollingContacts = false;
    if (materialA >= 0)
        useRollingContacts = collisionMaterials[materialA].rollingContacts > 0;

    // Barycentric-weighted accumulators for simplex A.
    float4 relVelocity = FLOAT4_ZERO;
    float4 prevPositionA = FLOAT4_ZERO;
    quaternion prevOrientationA = quaternion(0, 0, 0, 0);
    float radiusA = 0;
    float invMassA = 0;
    float invRotationalMassA = 0;

    for (int j = 0; j < sizeA; ++j)
    {
        int particleIndex = simplices[startA + j];
        float weight = contacts[i].pointA[j];

        relVelocity += velocities[particleIndex] * weight;
        prevPositionA += prevPositions[particleIndex] * weight;
        prevOrientationA += prevOrientations[particleIndex] * weight;
        invMassA += invMasses[particleIndex] * weight;
        invRotationalMassA += invRotationalMasses[particleIndex] * weight;
        radiusA += EllipsoidRadius(contacts[i].normal, prevOrientations[particleIndex], principalRadii[particleIndex].xyz) * weight;
    }

    // If the collider has a rigidbody, make the velocity relative to it and let the
    // collider's material enable rolling contacts as well.
    int rigidbodyIndex = shapes[colliderIndex].rigidbodyIndex;
    if (rigidbodyIndex >= 0)
    {
        // Note: unlike rA, that is expressed in solver space, rB is expressed in world space.
        relVelocity -= GetRigidbodyVelocityAtPoint(RW_rigidbodies[rigidbodyIndex],
                                                   contacts[i].pointB,
                                                   asfloat(linearDeltasAsInt[rigidbodyIndex]),
                                                   asfloat(angularDeltasAsInt[rigidbodyIndex]),
                                                   inertialSolverFrame[0]);

        int materialB = shapes[colliderIndex].materialIndex;
        if (materialB >= 0 && collisionMaterials[materialB].rollingContacts > 0)
            useRollingContacts = true;
    }

    // Signed distance along the normal between A's surface and the contact point on B.
    contacts[i].dist = dot(prevPositionA - contacts[i].pointB, contacts[i].normal) - radiusA;

    // Contact point on A's surface.
    float4 pointOnA = contacts[i].pointB + contacts[i].normal * contacts[i].dist;

    // Rebuild the contact basis from the relative velocity and the normal.
    CalculateBasis(relVelocity, contacts[i].normal, contacts[i].tangent);

    // Contact masses for A. The epsilon keeps the reciprocal finite.
    float4 invInertiaTensorA = 1.0/(GetParticleInertiaTensor(radiusA, invRotationalMassA) + FLOAT4_EPSILON);
    CalculateContactMassesA(invMassA, invInertiaTensorA, prevPositionA, prevOrientationA, pointOnA, useRollingContacts,
                            contacts[i].normal, contacts[i].tangent, GetBitangent(contacts[i]),
                            effectiveMasses[i].normalInvMassA, effectiveMasses[i].tangentInvMassA, effectiveMasses[i].bitangentInvMassA);

    // Contact masses for B: cleared when there is no rigidbody, calculated from it otherwise.
    if (rigidbodyIndex < 0)
    {
        ClearContactMasses(effectiveMasses[i].normalInvMassB, effectiveMasses[i].tangentInvMassB, effectiveMasses[i].bitangentInvMassB);
    }
    else
    {
        CalculateContactMassesB(RW_rigidbodies[rigidbodyIndex], inertialSolverFrame[0].frame, contacts[i].pointB,
                                contacts[i].normal, contacts[i].tangent, GetBitangent(contacts[i]),
                                effectiveMasses[i].normalInvMassB, effectiveMasses[i].tangentInvMassB, effectiveMasses[i].bitangentInvMassB);

        // Register this contact on the rigidbody; Project reads the total as constraintCount.
        InterlockedAdd(RW_rigidbodies[rigidbodyIndex].constraintCount, 1);
    }
}

// Project kernel: one thread per contact.
// Solves adhesion and penetration along the contact normal, accumulates the resulting
// position deltas on the particles of simplex A and, when the collider has a rigidbody,
// applies the opposite impulse to it. Contacts whose collider is a trigger are skipped.
[numthreads(128, 1, 1)]
void Project (uint3 id : SV_DispatchThreadID) 
{
    unsigned int i = id.x;

    // dispatchBuffer[3] is the exclusive upper bound for valid contact indices.
    if (i >= dispatchBuffer[3]) return;

    // Skip contacts involving triggers:
    if (shapes[contacts[i].bodyB].isTrigger())
        return;

    int simplexSize;
    int simplexStart = GetSimplexStartAndSize(contacts[i].bodyA, simplexSize);
    int colliderIndex = contacts[i].bodyB;

    // Get the rigidbody index (might be < 0, in that case there's no rigidbody present)
    int rigidbodyIndex = shapes[colliderIndex].rigidbodyIndex;

    // Time horizon used to advance the rigidbody, and number of substeps left in the step.
    float frameEnd = stepTime * steps;
    float substepsToEnd = timeLeft / substepTime;

    // Combine collision materials (use material from first particle in simplex)
    collisionMaterial material = CombineCollisionMaterials(collisionMaterialIndices[simplices[simplexStart]], shapes[colliderIndex].materialIndex);

    // Get relative velocity at contact point.
    // As we do not consider true ellipses for collision detection, particle contact points are never off-axis.
    // So particle angular velocity does not contribute to normal impulses, and we can skip it.
    float4 simplexPosition = FLOAT4_ZERO;
    float4 simplexPrevPosition = FLOAT4_ZERO;
    float simplexRadius = 0;

    // Barycentric interpolation of the simplex state at the contact point.
    for (int j = 0; j < simplexSize; ++j)
    {
        int particleIndex = simplices[simplexStart + j];
        simplexPosition += positions[particleIndex] * contacts[i].pointA[j];
        simplexPrevPosition += prevPositions[particleIndex] * contacts[i].pointA[j];
        simplexRadius += EllipsoidRadius(contacts[i].normal, orientations[particleIndex], principalRadii[particleIndex].xyz) * contacts[i].pointA[j];
    }

    // project position to the end of the full step (extrapolates when substepsToEnd > 1),
    // then move from the simplex center to its surface along the contact normal:
    float4 posA = lerp(simplexPrevPosition, simplexPosition, substepsToEnd);
    posA += -contacts[i].normal * simplexRadius;

    // Contact point on B, advanced by the rigidbody's velocity over frameEnd when there is one.
    float4 posB = contacts[i].pointB;
    int rbContacts = 1;
    if (rigidbodyIndex >= 0)            
    {
        posB += GetRigidbodyVelocityAtPoint(rigidbodies[rigidbodyIndex], contacts[i].pointB, 
                                                        asfloat(linearDeltasAsInt[rigidbodyIndex]), 
                                                        asfloat(angularDeltasAsInt[rigidbodyIndex]), inertialSolverFrame[0]) * frameEnd;
        rbContacts = rigidbodies[rigidbodyIndex].constraintCount;
    }

    // B's inverse mass is scaled by the number of contacts counted on its rigidbody.
    float normalInvMass = effectiveMasses[i].normalInvMassA + effectiveMasses[i].normalInvMassB * rbContacts;
    float lambda = SolveAdhesion(contacts[i], normalInvMass, posA, posB, material.stickDistance, material.stickiness, stepTime);

    lambda += SolvePenetration(contacts[i], normalInvMass, posA, posB, maxDepenetration * stepTime);

    if (abs(lambda) > EPSILON)
    {
        // The correction was computed for the end of the step, so it is spread over the remaining substeps.
        float4 delta = lambda * contacts[i].normal * BaryScale(contacts[i].pointA) / substepsToEnd;

        // Uses a distinct loop index: redeclaring 'j' here would shadow the index of the
        // interpolation loop above, which FXC reports as a conflicting loop control variable.
        for (int k = 0; k < simplexSize; ++k)
        {
            int particleIndex = simplices[simplexStart + k];
            float4 delta1 = delta * invMasses[particleIndex] * contacts[i].pointA[k];

            // Several contacts may touch the same particle, hence the atomic accumulation.
            AtomicAddPositionDelta(particleIndex, delta1);
        }

        // Apply the opposite impulse to the rigidbody at the contact point on B.
        if (rigidbodyIndex >= 0)
        {
            ApplyImpulse(rigidbodyIndex, -lambda / frameEnd * contacts[i].normal, contacts[i].pointB, inertialSolverFrame[0].frame);
        }
    }
}

// Apply kernel: one thread per entry of particleIndices.
// Looks up the particle assigned to this thread and applies its accumulated position
// delta to 'positions' via ApplyPositionDelta, passing sorFactor along.
[numthreads(128, 1, 1)]
void Apply (uint3 id : SV_DispatchThreadID)
{
    // Threads beyond the particle count have nothing to process.
    if (id.x >= particleCount)
        return;

    int particleIndex = particleIndices[id.x];

    ApplyPositionDelta(positions, particleIndex, sorFactor);
}


